from ModularUtils.FunctionsConstant import getdoKey
from ModularUtils.ControllerConstants import generate_permutations


class CausalGraph:
    """Causal graph over observed variables plus latent confounders.

    Attributes set in ``__init__``:
        DAG_desc, Complete_DAG_desc: the graph name passed as ``name``.
        Observed_DAG: ``dag`` as given (``var -> list of observed parents``).
        num_confs: number of confounders in ``confs``.
        confTochild: ``confs`` as given (``confounder -> list of children``).
        Complete_DAG: ``node -> list of parents`` covering every confounder
            (as a parentless root) and every observed variable; an observed
            variable's parents are its confounders followed by its observed
            parents.
        latent_conf: ``observed var -> list of its confounders``.
        complete_labels: node names of ``Complete_DAG`` in insertion order.
        label_names: observed variable names in ``dag`` order.
        label_dim: the caller's ``dims`` dict (NOT copied), extended with
            ``num_latent`` for every confounder.
        image_labels, rep_labels: initialised to ``None``; set by callers.
    """

    def __init__(self, name, dag, confs, dims, num_latent):
        self.DAG_desc = name
        self.Complete_DAG_desc = name
        self.Observed_DAG = dag

        self.num_confs = len(confs)
        self.confTochild = confs

        # Confounders are parentless roots. Key them by their actual names so
        # every parent referenced in Complete_DAG is itself a node of the graph.
        self.Complete_DAG = {cnf: [] for cnf in confs}

        # Observed variables start with no parents/confounders. This runs after
        # the confounder roots, so an observed name equal to a confounder name
        # resets that entry.
        self.latent_conf = {}
        for var in self.Observed_DAG:
            self.Complete_DAG[var] = []
            self.latent_conf[var] = []

        # Attach each confounder to its observed children (KeyError if a child
        # is not an observed variable).
        for cnf, children in self.confTochild.items():
            for var in children:
                self.latent_conf[var].append(cnf)
                self.Complete_DAG[var].append(cnf)

        # Observed parents follow the confounder parents; `+` builds a new list
        # so the caller's `dag` lists are not mutated.
        for var, parents in self.Observed_DAG.items():
            self.Complete_DAG[var] = self.Complete_DAG[var] + parents

        self.complete_labels = list(self.Complete_DAG)
        self.label_names = list(self.Observed_DAG)

        # NOTE: `dims` is stored by reference and mutated in place; callers may
        # rely on seeing confounder (and later-added) dimensions in their dict.
        self.label_dim = dims
        for cnf in self.confTochild:
            self.label_dim[cnf] = num_latent

        self.image_labels = None
        self.rep_labels = None







def set_TrueMediator(noise_states, latent_state, obs_state, Data_intervs):
    """Build the configuration for the "TrueMediator" synthetic experiment.

    The observed graph is ``D -> I`` and ``{U0, I} -> genC``, with ``I``
    registered as an image label. No latent confounders are declared, and
    the counterfactual / twin-network structures are returned empty.

    Args:
        noise_states: Second element of every ``noise_params`` entry, and the
            dimension recorded for each non-image noise label ``"n<label>"``
            in the returned ``label_dim``.
        latent_state: Forwarded to ``CausalGraph`` as ``num_latent`` and used
            in confounder ``noise_params`` entries (there are none here).
        obs_state: Unused in this configuration.
        Data_intervs: Unused in this configuration.

    Returns:
        A 24-tuple ``(DAG_desc, Complete_DAG_desc, Complete_DAG,
        complete_labels, Observed_DAG, label_names, image_labels, rep_labels,
        interv_queries, cf_queries, latent_conf, confTochild, exogenous,
        cf_intervene, cf_observe, cf_evidence, cflabel_names, twin_map,
        Twin_Network, cf_exogenous, noise_params, train_mech_dict, label_dim,
        plot_title)``.
    """
    Observed_DAG = {
        'D': [],
        "U0": [],
        "I": ['D'],
        "genC": ['U0', 'I'],
    }
    confTochild = {}  # no latent confounders in this experiment
    # 'I' (registered below as an image label) is given dimension 0.
    label_dim = {'D': 2, 'U0': 4, 'I': 0, "genC": 3}
    G = CausalGraph(name="TrueMediator", dag=Observed_DAG, confs=confTochild,
                    dims=label_dim, num_latent=latent_state)
    plot_title = "TrueMediator Synthetic Experiment"
    G.image_labels = ['I']
    G.rep_labels = []

    # Queries to evaluate; the readable "expr" is replaced by getdoKey's key.
    intervention_list = [
        {"expr": "P(genC)", "obs": ['genC'], "inter_vars": []},
    ]
    for intervention in intervention_list:
        intervention["expr"] = getdoKey(intervention["obs"], intervention["inter_vars"])

    # Expand each query over every value combination of its intervened vars.
    interv_queries = []
    for intervention in intervention_list:
        perms = generate_permutations([label_dim[lb] for lb in intervention["inter_vars"]])
        key_val = [dict(zip(intervention["inter_vars"], comb)) for comb in perms]
        interv_queries.append({"obs": intervention["obs"], "intervs": key_val, "expr": intervention["expr"]})

    cf_queries = []

    # One exogenous noise variable "n<label>" per non-image observed label.
    exogenous = {label: "n" + label for label in G.label_names
                 if label not in G.image_labels}

    # Counterfactual / twin-network structures are unused in this experiment.
    cflabel_names = []
    Twin_Network = {}
    cf_exogenous = {}
    cf_intervene = {}
    cf_observe = []
    cf_evidence = {}
    twin_map = {}

    noise_params = {"n" + label: (0.5, noise_states) for label in Observed_DAG}
    for conf in confTochild:
        noise_params[conf] = (0.1, latent_state)

    # Per-label training mechanisms.
    # "compare": the variables whose joint is needed;
    # "parents": the variables that need to be intervened on.
    train_mech_dict = {
        "D": [{'parents': [], 'intv': {}, 'compare': ['D']}],
        "U0": [{'parents': [], 'intv': {}, 'compare': ['U0']}],
        "I": [{'parents': [], 'intv': {}, 'compare': ['I']}],
        "genC": [{'parents': ['U0', 'I'], 'intv': {'U0': 0, 'I': 0}, 'compare': ['genC']}],
    }

    # Record noise-label dimensions in G.label_dim, the dict that is returned,
    # instead of relying on it aliasing the local `label_dim`.
    for label in Observed_DAG:
        if label not in G.image_labels:
            G.label_dim["n" + label] = noise_states

    return G.DAG_desc, G.Complete_DAG_desc, G.Complete_DAG, G.complete_labels, G.Observed_DAG, G.label_names, G.image_labels, G.rep_labels, interv_queries, cf_queries, G.latent_conf, \
           G.confTochild, exogenous, cf_intervene, cf_observe, cf_evidence, cflabel_names, twin_map, Twin_Network, cf_exogenous, \
           noise_params, train_mech_dict, G.label_dim, plot_title



